Add MambaForSequenceClassification #29552
Conversation
Hi @mjschook, thanks for opening this PR and adding this!
AFAICT, these changes seem reasonable to me. @ArthurZucker is off for a week, so let's wait for him to come back to confirm whether there's any reason not to add this to Mamba.
A few things will need to be added:
- Tests for the model, i.e. equivalent to `create_and_check_mamba_model`, and the model should be added to `all_model_classes` (see the sketch after this list)
- The model needs to be documented in `mamba.md`
- All the tests in the CI should be passing
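For illustration, a hedged sketch of what such a tester method might look like, mirroring the `create_and_check_mamba_model` pattern; the method name, the `ids_tensor`/`torch_device` helpers, and the tester attributes are assumed from the library's common test scaffolding, not taken from this PR:

```python
# Hypothetical tester method for the new head; names and shapes are assumptions.
def create_and_check_mamba_for_sequence_classification(self, config, input_ids, *args):
    config.num_labels = self.num_labels
    model = MambaForSequenceClassification(config)
    model.to(torch_device)
    model.eval()
    labels = ids_tensor([self.batch_size], self.num_labels)
    result = model(input_ids, labels=labels)
    # The head should produce one logit vector per sequence
    self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
    self.parent.assertIsNotNone(result.loss)
```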
x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = ACT2FN[self.config.hidden_act](x)
The activation layer should be set in the `__init__` and then called here, i.e.:

def __init__(...):
    ...
    self.activation = ACT2FN[config.hidden_act]

def forward(...):
    ...
    x = self.activation(x)
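Putting the pieces together, a minimal sketch of the full head with the activation resolved once in `__init__`, as suggested above; the `classifier_dropout` field and the `out_proj` layer are illustrative assumptions, not necessarily what the PR uses:

```python
import torch.nn as nn
from transformers.activations import ACT2FN


class MambaClassificationHead(nn.Module):
    """Sketch of a sentence-level classification head on top of MambaModel."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `classifier_dropout` is an assumed config field; fall back to 0.1 if absent
        self.dropout = nn.Dropout(getattr(config, "classifier_dropout", 0.1))
        # Resolve the activation once here instead of indexing ACT2FN in forward()
        self.activation = ACT2FN[config.hidden_act]
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```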
self.classifier = MambaClassificationHead(config)

for param in self.base_model.parameters():
    param.requires_grad = False
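For context, a hedged sketch of where these lines might sit in the model's `__init__`; the `backbone` attribute name and the use of `base_model` follow the library's usual conventions and are assumptions here rather than the PR's exact code:

```python
class MambaForSequenceClassification(MambaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.backbone = MambaModel(config)  # assumed attribute name
        self.classifier = MambaClassificationHead(config)

        # Freeze the backbone so only the classification head trains by default
        for param in self.base_model.parameters():
            param.requires_grad = False

        self.post_init()
```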
I'm not sure whether we actually want to freeze the params for the base model here, but I found there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0], and the initialized values for the mixer don't satisfy that assertion.
So... I froze them and made sure the classification head params were initialized to satisfy the test. It makes intuitive sense to me to freeze them in the case of transfer learning for this task. I did confirm that running LoRA PEFT with `target_modules=["x_proj", "embeddings", "in_proj", "out_proj"]` and `task_type=TaskType.SEQ_CLS` does unfreeze the target modules (see the sketch below), so it appears to work fine, but I'm not sure if we want to force them to be frozen by default.
Anyway, happy to adjust if there's a better practice to follow here.
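For reference, a hedged sketch of the LoRA/PEFT setup described above; the rank and other hyperparameters are illustrative placeholders, and `model` is assumed to be an instance of the proposed `MambaForSequenceClassification`:

```python
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,  # illustrative rank
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
)
peft_model = get_peft_model(model, lora_config)
# LoRA adapters on the target modules are trainable even though the backbone was frozen
peft_model.print_trainable_parameters()
```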
…prepare_config_and_inputs
…classifier head linear layer weights
Hey! We usually do the following before merging such PRs:
- Create a feature request issue to add this new class
- Wait until the community picks this up
- Wait until there are actually pretrained checkpoints released by the community or the authors

As is, this does not really help anyone, as it can be easily implemented by anyone who wants to train a model, no?
Thanks for the feedback @ArthurZucker - I'll close it since I went in another direction, using prompt tuning instead. I'll keep the process you laid out in mind for the future. =)
It would be helpful to have this class. Looking forward to #30431
What does this PR do?
Adds `MambaForSequenceClassification` for sequence classification with the `MambaModel` backbone.
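To illustrate the intended usage, a hedged sketch; the checkpoint name and label count are placeholders, since no pretrained classification checkpoint is attached to this PR:

```python
import torch
from transformers import AutoTokenizer, MambaForSequenceClassification

# Placeholder checkpoint; only the backbone weights would be pretrained
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForSequenceClassification.from_pretrained(
    "state-spaces/mamba-130m-hf", num_labels=2
)

inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class_id = logits.argmax(dim=-1).item()
```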
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker thanks for your work bringing in Mamba! I'm wondering if there's any objection to adding a `MambaForSequenceClassification` model? I followed the template example as best I could, and I'm happy to continue with adding the new test and finishing it up.